백준 1067번

이동

Posted by ChoiCube84 on August 25, 2023 · 19 mins read

백준 1067번

오늘 풀어본 문제는 백준의 1067번 문제1이다. 문제 풀이에 사용한 언어는 C++ 이다.

solved.ac 기준 CLASS

CLASS 8

문제 정보

이 문제의 내용과 조건은 다음과 같다.

문제

$N$ 개의 수가 있는 $X$ 와 $Y$ 가 있다. 이때 $X$ 나 $Y$ 를 순환 이동시킬 수 있다. 순환 이동이란 마지막 원소를 제거하고 그 수를 맨 앞으로 다시 삽입하는 것을 말한다. 예를 들어, ${1, 2, 3}$ 을 순환 이동시키면 ${3, 1, 2}$ 가 될 것이고, ${3, 1, 2}$ 는 ${2, 3, 1}$ 이 된다. 순환 이동은 $0$ 번 또는 그 이상 할 수 있다. 이 모든 순환 이동을 한 후에 점수를 구하면 된다. 점수 $S$ 는 다음과 같이 구한다.

$S = X[0] \times Y[0] + X[1] \times Y[1] + \cdots + X[N-1] \times Y[N-1]$

이때 $S$ 를 최대로 하면 된다.

입력

첫째 줄에 $N$ 이 주어진다. 둘째 줄에는 $X$ 에 들어있는 $N$ 개의 수가 주어진다. 셋째 줄에는 $Y$ 에 있는 수가 모두 주어진다. $N$ 은 $60,000$ 보다 작거나 같은 자연수이고, $X$ 와 $Y$ 에 들어있는 모든 수는 $100$ 보다 작은 자연수 또는 $0$ 이다.

출력

첫째 줄에 $S$ 의 최댓값을 출력한다.

풀이과정

1번째 시도

이 문제는 FFT를 이용해야 하는 문제이다. 정확하게 말하자면, ‘convolution’ 연산을 하는 과정에서 FFT를 이용해야 한다. 나이브(naive) 한 방법으로 컨볼루션 연산을 하면 시간복잡도가 $O(n^2)$ 이 되는데, FFT를 활용하게 되면 시간 복잡도를 $O(n \log n)$으로 줄일 수 있는 것이다.

물론 이 문제는 단순히 두 벡터의 컨볼루션 연산을 한 뒤 결과를 내보내면 되는 것이 아니라, 각 벡터에 이동을 적용하여 연산한 결과들의 최댓값을 얻어내야 한다. 이는 첫 번째 벡터를 그대로 이어붙이고, 두 번째 벡터를 뒤집어서 컨볼루션을 진행하면 된다.

예를 들어 $(1,\ 2,\ 3,\ 4)$ 와 $(6,\ 7,\ 8,\ 5)$ 에 대한 점수를 구하기 위해서는 다음 식들을 계산하고 대소를 비교해야 한다.

  • $1 \times 6 + 2 \times 7 + 3 \times 8 + 4 \times 5$
  • $4 \times 6 + 1 \times 7 + 2 \times 8 + 3 \times 5$
  • $3 \times 6 + 4 \times 7 + 1 \times 8 + 2 \times 5$
  • $2 \times 6 + 3 \times 7 + 4 \times 8 + 1 \times 5$

여기서 첫 번째 수들은 이동을 시키지 않을 것을 볼 수 있는데, 두 번째 수들만 이동시킨 결과만 계산해도 첫 번째 수들을 임의로 이동시킨 결과 중 하나와 일치하기 때문에 냅둔 것이다.

이제 앞에서 내가 말한 방식대로 첫 번째 숫자들을 그대로 이은 뒤 두 번째 숫자들을 뒤집으면 $(1,\ 2,\ 3,\ 4,\ 1,\ 2,\ 3,\ 4)$ 과 $(5,\ 8,\ 7,\ 6)$ 가 되고, 이 둘을 convolution 한 결과의 각 항들은 다음과 같다.

  • $1 \times 5$
  • $1 \times 8 + 2 \times 5$
  • $1 \times 7 + 2 \times 8 + 3 \times 5$
  • $1 \times 6 + 2 \times 7 + 3 \times 8 + 4 \times 5$
  • $2 \times 6 + 3 \times 7 + 4 \times 8 + 1 \times 5$
  • $3 \times 6 + 4 \times 7 + 1 \times 8 + 2 \times 5$
  • $4 \times 6 + 1 \times 7 + 2 \times 8 + 3 \times 5$
  • $1 \times 6 + 2 \times 7 + 3 \times 8 + 4 \times 5$
  • $2 \times 6 + 3 \times 7 + 4 \times 8$
  • $3 \times 6 + 4 \times 7$
  • $4 \times 6$

여기서 우리가 필요한 식들이 모두 나오게 된다! 이제 이것들을 모조리 비교하여 최댓값을 구해내면 그것이 우리가 찾는 점수가 되는 것이다.

코드는 다음과 같이 작성하였다.

#include <bits/stdc++.h>

using namespace std;

double const_pi(void) {
	return std::atan(1) * 4;
}

void FFT(std::vector<std::complex<double>>& a, const std::complex<double>& w) {
	int n = a.size();
	if (n == 1) return;

	std::vector<std::complex<double>> a_even(n / 2), a_odd(n / 2);
	for (int i = 0; i < n / 2; i++) {
		a_even[i] = a[2 * i];
		a_odd[i] = a[2 * i + 1];
	}

	std::complex<double> w_sq = w * w;
	FFT(a_even, w_sq);
	FFT(a_odd, w_sq);

	std::complex<double> w_i = 1;
	for (int i = 0; i < n / 2; i++) {
		a[i] = a_even[i] + w_i * a_odd[i];
		a[i + n / 2] = a_even[i] - w_i * a_odd[i];
		w_i *= w;
	}
}

std::vector<std::complex<double>> convolution(std::vector<std::complex<double>> a, std::vector<std::complex<double>> b, bool getIntegerResult = false) {
	int n = 1;
	double pi = const_pi();

	while (n <= a.size() || n <= b.size()) {
		n <<= 1;
	}
	n <<= 1;

	a.resize(n);
	b.resize(n);

	std::vector<std::complex<double>> c(n);

	std::complex<double> w(cos(2 * pi / n), sin(2 * pi / n));

	FFT(a, w);
	FFT(b, w);

	for (int i = 0; i < n; i++) {
		c[i] = a[i] * b[i];
	}

	FFT(c, std::complex(w.real(), -w.imag()));

	for (int i = 0; i < n; i++) {
		c[i] /= std::complex<double>(n, 0);
		if (getIntegerResult) {
			c[i] = std::complex(round(c[i].real()), round(c[i].imag()));
		}
	}

	return c;
}

int main(void) {
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);

	int N;
	cin >> N;

	vector<complex<double>> X(2 * N);
	vector<complex<double>> Y(N);

	for (int i = 0; i < N; i++) {
		int tempNum;
		cin >> tempNum;
		X[i] = X[i + N] = complex<double>(tempNum, 0);
	}

	for (int i = 0; i < N; i++) {
		int tempNum;
		cin >> tempNum;
		Y[N - i - 1] = complex<double>(tempNum, 0);
	}

	auto Z = convolution(X, Y, true);

	int result = 0;

	for (auto& i : Z) {
		cout << i << " ";
		result = max(result, static_cast<int>(i.real()));
	}
	cout << endl;
	cout << result;

	return 0;
}

실행 결과 ‘출력 초과’가 나왔다.

2번째 시도

제출하는 순간 아차 싶었고, 역시 문제가 발생했다. 테스트를 위해서 Z 벡터의 값을 출력했었는데, 이를 지우지 않고 제출한 것이다. 해결 방법은 당연하게도 테스트용 출력 코드들을 지우는 것이었다.

코드는 다음과 같이 작성하였다.

#include <bits/stdc++.h>

using namespace std;

double const_pi(void) {
	return std::atan(1) * 4;
}

void FFT(std::vector<std::complex<double>>& a, const std::complex<double>& w) {
	int n = a.size();
	if (n == 1) return;

	std::vector<std::complex<double>> a_even(n / 2), a_odd(n / 2);
	for (int i = 0; i < n / 2; i++) {
		a_even[i] = a[2 * i];
		a_odd[i] = a[2 * i + 1];
	}

	std::complex<double> w_sq = w * w;
	FFT(a_even, w_sq);
	FFT(a_odd, w_sq);

	std::complex<double> w_i = 1;
	for (int i = 0; i < n / 2; i++) {
		a[i] = a_even[i] + w_i * a_odd[i];
		a[i + n / 2] = a_even[i] - w_i * a_odd[i];
		w_i *= w;
	}
}

std::vector<std::complex<double>> convolution(std::vector<std::complex<double>> a, std::vector<std::complex<double>> b, bool getIntegerResult = false) {
	int n = 1;
	double pi = const_pi();

	while (n <= a.size() || n <= b.size()) {
		n <<= 1;
	}
	n <<= 1;

	a.resize(n);
	b.resize(n);

	std::vector<std::complex<double>> c(n);

	std::complex<double> w(cos(2 * pi / n), sin(2 * pi / n));

	FFT(a, w);
	FFT(b, w);

	for (int i = 0; i < n; i++) {
		c[i] = a[i] * b[i];
	}

	FFT(c, std::complex(w.real(), -w.imag()));

	for (int i = 0; i < n; i++) {
		c[i] /= std::complex<double>(n, 0);
		if (getIntegerResult) {
			c[i] = std::complex(round(c[i].real()), round(c[i].imag()));
		}
	}

	return c;
}

int main(void) {
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);

	int N;
	cin >> N;

	vector<complex<double>> X(2 * N);
	vector<complex<double>> Y(N);

	for (int i = 0; i < N; i++) {
		int tempNum;
		cin >> tempNum;
		X[i] = X[i + N] = complex<double>(tempNum, 0);
	}

	for (int i = 0; i < N; i++) {
		int tempNum;
		cin >> tempNum;
		Y[N - i - 1] = complex<double>(tempNum, 0);
	}

	auto Z = convolution(X, Y, true);

	int result = 0;

	for (auto& i : Z) {
		result = max(result, static_cast<int>(i.real()));
	}

	cout << result;

	return 0;
}

그러자 모든 테스트 케이스를 통과하고 정답이 나오는 것을 확인할 수 있었다.

마무리

이젠 하다하다 CLASS 8 문제를 풀게되었다. 예상했겠지만, 세미나에서 배웠던 내용으로 FFT를 활용한 문제였다.

FFT를 써서 문제를 푼 것까지는 좋은데, FFT가 뭐하는 건지 솔직히 아직 제대로 이해하지 못했다. 지난 학기에 응용복소함수론 강의를 듣긴 했었는데, 아마 그 내용과 관련이 있나 싶기도 하다. 조만간 FFT는 제대로 공부해봐야 할 것 같다.

그리고 구현해둔 FFT 코드는 ‘자료구조’ 는 아니긴 하나, 이후에 유용하게 사용할 수 있을 것 같아 깃험 레포지토리2에 저장해두었다. 필요한 사람은 참고하면 된다.

오늘의 PS는 여기까지!


1: https://www.acmicpc.net/problem/1067
2: https://github.com/ChoiCube84/baekjoon-solutions/blob/main/C%2B%2B/custom_data_structures/custom_algorithms.hpp